{
 "cells": [
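  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pipeline\n",
    "`pipeline` is one of the basic tools in Hugging Face Transformers. Think of it as an end-to-end, single-call way to run a Transformer model.\n",
    "\n",
    "It connects a model with its necessary preprocessing and postprocessing steps, allowing us to directly input any text and get an intelligible answer.\n",
    "\n",
    "Given a task, `pipeline` automatically loads a pretrained model and then runs the following three steps on your input:\n",
    "1. Preprocess the input text so that the model can read it\n",
    "2. Run the model\n",
    "3. Postprocess the model output so that the prediction is human-readable\n",
    "\n",
    "Here is an example:"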
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "clf = pipeline('sentiment-analysis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'POSITIVE', 'score': 0.9998709559440613}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf('Haha, today is a nice day!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It also accepts several sentences at once and predicts them together:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'POSITIVE', 'score': 0.9998160600662231},\n",
       " {'label': 'POSITIVE', 'score': 0.9998552799224854},\n",
       " {'label': 'NEGATIVE', 'score': 0.999782383441925}]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf(['good','nice','bad'])"
   ]
  },
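  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps listed earlier (preprocess, run the model, postprocess) can also be written out by hand. The code below is a minimal sketch, not how `pipeline` is actually implemented. It assumes PyTorch is installed and that the default sentiment model is `distilbert-base-uncased-finetuned-sst-2-english`; the helper names `postprocess` and `classify` are made up for this illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def postprocess(logits, id2label):\n",
    "    # Step 3: turn raw logits into readable {'label', 'score'} dicts.\n",
    "    results = []\n",
    "    for row in logits:\n",
    "        # Softmax over the classes (shifted by the max for numerical stability).\n",
    "        top = max(row)\n",
    "        exps = [math.exp(x - top) for x in row]\n",
    "        total = sum(exps)\n",
    "        probs = [e / total for e in exps]\n",
    "        best = probs.index(max(probs))\n",
    "        results.append({'label': id2label[best], 'score': probs[best]})\n",
    "    return results\n",
    "\n",
    "\n",
    "def classify(texts, checkpoint='distilbert-base-uncased-finetuned-sst-2-english'):\n",
    "    # Imported here so that postprocess() can be tried without these libraries.\n",
    "    import torch\n",
    "    from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
    "\n",
    "    # Step 1: preprocess the text into tensors the model can read.\n",
    "    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')\n",
    "    # Step 2: run the model.\n",
    "    with torch.no_grad():\n",
    "        logits = model(**inputs).logits\n",
    "    # Step 3: postprocess the logits into readable predictions.\n",
    "    return postprocess(logits.tolist(), model.config.id2label)\n",
    "```\n",
    "\n",
    "If the checkpoint can be downloaded, `classify(['good', 'nice', 'bad'])` should return a list in the same format as the `clf(...)` call above."
   ]
  },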
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tasks supported by `pipeline` include:\n",
    "\n",
    "- \"feature-extraction\": will return a FeatureExtractionPipeline.\n",
    "- \"text-classification\": will return a TextClassificationPipeline.\n",
    "- \"sentiment-analysis\": (alias of \"text-classification\") will return a TextClassificationPipeline.\n",
    "- \"token-classification\": will return a TokenClassificationPipeline.\n",
    "- \"ner\" (alias of \"token-classification\"): will return a TokenClassificationPipeline.\n",
    "- \"question-answering\": will return a QuestionAnsweringPipeline.\n",
    "- \"fill-mask\": will return a FillMaskPipeline.\n",
    "- \"summarization\": will return a SummarizationPipeline.\n",
    "- \"translation_xx_to_yy\": will return a TranslationPipeline.\n",
    "- \"text2text-generation\": will return a Text2TextGenerationPipeline.\n",
    "- \"text-generation\": will return a TextGenerationPipeline.\n",
    "- \"zero-shot-classification\": will return a ZeroShotClassificationPipeline.\n",
    "- \"conversational\": will return a ConversationalPipeline."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Have a try: Zero-shot-classification\n",
    "Zero-shot learning means training a model that can predict arbitrary labels, including labels that never appeared in the training set.\n",
    "\n",
    "One approach to zero-shot learning is to train an inference model on the NLI (textual entailment) task. For example:\n",
    "```python\n",
    "premise = 'Who are you voting for in 2020?'\n",
    "hypothesis = 'This text is about politics.'\n",
    "```\n",
    "Here we have a premise and a hypothesis. The NLI task is to predict whether the hypothesis holds given the premise.\n",
    "\n",
    "With a model trained this way, we can simply replace `politics` in the hypothesis with any other word, which gives us zero-shot classification.\n",
    "\n",
    "About the NLI task: it classifies whether two sentences are logically linked, using three labels (contradiction, neutral, entailment). NLI stands for natural language inference.\n",
    "\n",
    "Further reading:\n",
    "- Official Zero-shot-classification Pipeline documentation: https://huggingface.co/transformers/main_classes/pipelines.html#transformers.ZeroShotClassificationPipeline\n",
    "- An introduction to zero-shot learning: https://mp.weixin.qq.com/s/6aBzR0O3pwA8-btsuDX82g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = pipeline('zero-shot-classification')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'sequence': 'A helicopter is flying in the sky',\n",
       "  'labels': ['machine', 'animal'],\n",
       "  'scores': [0.9938627481460571, 0.006137280724942684]},\n",
       " {'sequence': 'A bird is flying in the sky',\n",
       "  'labels': ['animal', 'machine'],\n",
       "  'scores': [0.9987970590591431, 0.0012029369827359915]}]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf(sequences=[\"A helicopter is flying in the sky\",\n",
    "               \"A bird is flying in the sky\"],\n",
    "    candidate_labels=['animal','machine'])  # the labels can be anything you choose"
   ]
  },
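  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Roughly, the NLI idea described above turns each candidate label into a hypothesis, scores it against the input text as the premise, and normalizes the entailment scores across the labels. The code below sketches only that bookkeeping. The template string comes from the example hypothesis in this notebook, the helper names are invented here, and the logits in the usage note are made up; in a real run the entailment logits would come from an NLI model.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def build_pairs(sequence, candidate_labels, template='This text is about {}.'):\n",
    "    # One (premise, hypothesis) pair per candidate label.\n",
    "    return [(sequence, template.format(label)) for label in candidate_labels]\n",
    "\n",
    "\n",
    "def rank_labels(candidate_labels, entailment_logits):\n",
    "    # Softmax over the entailment logits, then sort from most to least likely.\n",
    "    top = max(entailment_logits)\n",
    "    exps = [math.exp(x - top) for x in entailment_logits]\n",
    "    total = sum(exps)\n",
    "    scored = sorted(zip(candidate_labels, (e / total for e in exps)),\n",
    "                    key=lambda pair: pair[1], reverse=True)\n",
    "    return {'labels': [label for label, _ in scored],\n",
    "            'scores': [score for _, score in scored]}\n",
    "```\n",
    "\n",
    "For instance, `build_pairs('A bird is flying in the sky', ['animal', 'machine'])` yields two premise/hypothesis pairs, and `rank_labels(['animal', 'machine'], [0.5, 3.0])` (hypothetical logits) puts `machine` first."
   ]
  },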
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Have a try: Text Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "generator = pipeline('text-generation', model='liam168/chat-DialoGPT-small-zh')  # gpt2 is used by default; a specific model can also be given"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': '上午上班吧'}]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator('上午')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Have a try: Mask Filling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "unmasker = pipeline('fill-mask')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'sequence': 'What the heck?',\n",
       "  'score': 0.3783760964870453,\n",
       "  'token': 17835,\n",
       "  'token_str': ' heck'},\n",
       " {'sequence': 'What the hell?',\n",
       "  'score': 0.32931089401245117,\n",
       "  'token': 7105,\n",
       "  'token_str': ' hell'},\n",
       " {'sequence': 'What the fuck?',\n",
       "  'score': 0.14645449817180634,\n",
       "  'token': 26536,\n",
       "  'token_str': ' fuck'}]"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unmasker('What the <mask>?', top_k=3)  # note: the mask token differs between models and is not always <mask>"
   ]
  },
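  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the mask token differs between models, one option is to read it from the pipeline's tokenizer instead of hard-coding it. The code below is a small sketch that assumes a fill-mask pipeline object such as the `unmasker` created above. The helper names `masked_prompt` and `fill`, and the `{mask}` placeholder, are conventions invented for this example.\n",
    "\n",
    "```python\n",
    "def masked_prompt(template, mask_token):\n",
    "    # Swap the '{mask}' placeholder for the model's own mask token.\n",
    "    return template.replace('{mask}', mask_token)\n",
    "\n",
    "\n",
    "def fill(unmasker, template, top_k=3):\n",
    "    # The pipeline's tokenizer knows which mask token the model expects.\n",
    "    prompt = masked_prompt(template, unmasker.tokenizer.mask_token)\n",
    "    return unmasker(prompt, top_k=top_k)\n",
    "```\n",
    "\n",
    "With this, `fill(unmasker, 'What the {mask}?')` builds a valid prompt whether the model expects `<mask>` (RoBERTa-style) or `[MASK]` (BERT-style)."
   ]
  },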
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For more tasks, see the official course\n",
    "https://huggingface.co/course/chapter1/3?fw=pt"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
